# Copyright 2023 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Make fruitfly model."""

import itertools
import os
from typing import Sequence

from absl import app

from dm_control import mjcf
from dm_control.mujoco.wrapper.mjbindings import mjlib
from lxml import etree
import numpy as np

ASSET_RELPATH = 'assets/'
# Anchor the asset directory to this file's absolute location. A bare
# `os.path.dirname(__file__)` can be '' (e.g. when the script is launched by a
# relative name on Python < 3.9), and concatenating '/' onto that produced the
# filesystem-root path '/assets/'.
ASSET_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                         ASSET_RELPATH)
BASE_MODEL = 'drosophila_defaults.xml'
FLY_MODEL = 'drosophila_fused.xml'  # Pre-generated by fuse_fruitfly.py

FINAL_MODEL = 'fruitfly.xml'  # Output file of this script.

# Wing yaw-axis pitch angle: -47.5 degrees, expressed in radians.
_YAW_AXIS_PITCH = -47.5 * np.pi / 180

# Empirical mass values, milligrams.
_MASS = {'head': 0.15,
         'thorax': 0.34,
         'abdomen': 0.38,
         'leg': 0.0162,
         'wing': 0.008}


# Utility functions:
def mul_quat(quat_a, quat_b):
  """Returns the quaternion product quat_a * quat_b.

  Args:
    quat_a: Left operand, quaternion in (w, x, y, z) order.
    quat_b: Right operand, quaternion in (w, x, y, z) order.

  Returns:
    A new float64 array of shape (4,) filled in by MuJoCo's `mju_mulQuat`.
  """
  product = np.zeros(4, dtype=np.float64)
  mjlib.mju_mulQuat(product, quat_a, quat_b)
  return product


def quat_to_mat(quat):
  """Converts a (w, x, y, z) quaternion into a 3x3 rotation matrix.

  Args:
    quat: Quaternion in (w, x, y, z) order.

  Returns:
    A new float64 array of shape (3, 3).
  """
  flat_mat = np.zeros(9, dtype=np.float64)
  mjlib.mju_quat2Mat(flat_mat, quat)
  # MuJoCo writes 9 entries; reshape them in C (row-major) order.
  return np.reshape(flat_mat, (3, 3))


def mat_to_quat(mat):
  """Converts a 3x3 rotation matrix into a (w, x, y, z) quaternion.

  Args:
    mat: Rotation matrix as a numpy array with 9 elements (typically (3, 3)).

  Returns:
    A new float64 array of shape (4,) filled in by MuJoCo's `mju_mat2Quat`.
  """
  out = np.zeros(4, dtype=np.float64)
  flat_mat = mat.flatten()  # 1-D row-major copy of the matrix entries.
  mjlib.mju_mat2Quat(out, flat_mat)
  return out


def neg_quat(quat_a):
  """Returns the inverse rotation of the unit quaternion `quat_a`.

  Only the scalar (w) component is negated, giving (-w, x, y, z), which equals
  -conj(quat_a). Because q and -q encode the same rotation, for a unit
  quaternion this represents the inverse rotation. The input is not modified.

  Args:
    quat_a: Quaternion in (w, x, y, z) order, as a numpy array or any
      sequence of four numbers.

  Returns:
    A new float array of shape (4,).
  """
  # np.array always copies, so the caller's quaternion is never mutated, and
  # plain tuples/lists are accepted as well as arrays.
  neg_quat_a = np.array(quat_a, dtype=float)
  neg_quat_a[0] *= -1
  return neg_quat_a


def rot_vec_quat(vec, quat):
  """Rotates a 3-vector by a quaternion.

  Args:
    vec: Vector with 3 components.
    quat: Rotation quaternion in (w, x, y, z) order.

  Returns:
    A new float64 array of shape (3,) filled in by MuJoCo's `mju_rotVecQuat`.
  """
  rotated = np.zeros(3, dtype=np.float64)
  mjlib.mju_rotVecQuat(rotated, vec, quat)
  return rotated


def quat_z2vec(vec):
  """Builds the quaternion that rotates the z-axis onto the given vector.

  Args:
    vec: Target direction with 3 components.

  Returns:
    A new float64 array of shape (4,) filled in by MuJoCo's `mju_quatZ2Vec`.
  """
  out_quat = np.zeros(4, dtype=np.float64)
  mjlib.mju_quatZ2Vec(out_quat, vec)
  return out_quat


def change_body_frame(body, frame_pos, frame_quat):
  """Re-expresses `body` in a new local frame without moving its children.

  The body's `pos`/`quat` are overwritten with `frame_pos`/`frame_quat`.
  Every direct child that exposes a `pos` attribute is then re-expressed so
  that its position relative to the body's parent is unchanged; children that
  also expose `quat` get their orientation re-expressed the same way.

  Args:
    body: MJCF body element, modified in place.
    frame_pos: New body position in the parent frame; None means the origin.
    frame_quat: New body orientation (w, x, y, z) in the parent frame; None
      means the identity rotation.
  """
  if frame_pos is None:
    frame_pos = np.zeros(3)
  if frame_quat is None:
    frame_quat = np.array((1., 0, 0, 0))

  # Snapshot the old frame before it is overwritten below.
  old_pos = body.pos
  if old_pos is None:
    old_pos = np.zeros(3)
  shift = old_pos - frame_pos
  old_quat = body.quat
  if old_quat is None:
    old_quat = np.array((1., 0, 0, 0))
  frame_inv = neg_quat(frame_quat)
  # Rotation taking old-frame orientations into the new frame.
  delta_quat = mul_quat(frame_inv, old_quat)

  body.pos = frame_pos
  body.quat = frame_quat

  # Move every positioned child back to where it was before the change.
  positioned = (c for c in body.all_children() if hasattr(c, 'pos'))
  for child in positioned:
    if hasattr(child, 'quat'):
      child_quat = child.quat
      if child_quat is None:
        child_quat = np.array((1., 0, 0, 0))
      child.quat = mul_quat(delta_quat, child_quat)
    child_pos = child.pos
    if child_pos is None:
      child_pos = np.zeros(3)
    # Child position in the parent frame, measured from the new body origin,
    # then rotated into the new body frame.
    offset_in_parent = rot_vec_quat(child_pos, old_quat) + shift
    child.pos = rot_vec_quat(offset_in_parent, frame_inv)


def main(argv: Sequence[str]):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  print('Load base models.')
  with open(os.path.join(ASSET_DIR, BASE_MODEL), 'r') as f:
    modeltree = etree.XML(f.read(), etree.XMLParser(remove_blank_text=True))
  with open(os.path.join(ASSET_DIR, FLY_MODEL), 'r') as f:
    flytree = etree.XML(f.read(), etree.XMLParser(remove_blank_text=True))

  print('Combine fly model with defaults.')
  worldbody = modeltree.find('worldbody')
  worldbody.addprevious(flytree.find('asset'))

  print('Append root.')
  all_bodies = flytree.xpath('.//body')
  thorax = None
  for body in all_bodies:
    if body.get('name') is not None and 'Armature' == body.get('name'):
      thorax = body
  if thorax is not None:
    worldbody.append(thorax)

  print('Load as mjcf model.')
  model = mjcf.from_xml_string(etree.tostring(modeltree, pretty_print=True),
                               model_dir=ASSET_DIR)

  print('Fix shininess.')
  for material in model.find_all('material'):
    material.shininess = round(material.shininess * 1e6) * 1e-6

  print('Add global cameras.')
  model.worldbody.add('camera', name='hero', pos=(0.271, 0.270, -0.044),
                      xyaxes=(-0.641, 0.767, 0, -0.045, -0.037, 0.998))

  print('Remove _Armature suffix.')
  things = [model.find_all(thing) for thing in ['body', 'joint', 'geom']]
  things = sum(things, [])
  for thing in things:
    if thing.name is not None:
      thing.name = thing.name.replace('_Armature', '')

  print('Remove _body suffix.')
  for geom in model.find_all('geom'):
    if geom.name.endswith('_body'):
      geom.name = geom.name.replace('_body', '')
  for mesh in model.find_all('mesh'):
    if mesh.name.endswith('_body'):
      mesh.name = mesh.name.replace('_body', '')

  print('Rescale mm -> cm, remove (0, 0, 0).')
  for thing in things:
    if thing.pos is not None:
      thing.pos = thing.pos / 10.0
    if np.all(thing.pos == 0):
      thing.pos = None

  print('Use radians.')
  model.compiler.angle = 'radian'

  print('Set autolimits="true".')
  model.compiler.autolimits = 'true'

  print('Remove inertial clauses.')
  for body in model.find_all('body'):
    children = body.all_children()
    for child in children:
      if child.tag == 'inertial':
        child.remove()

  print('Make wings translucent.')
  model.find('material', 'membrane').rgba[3] = 0.4

  print('Get, rename thorax.')
  thorax = model.find('body', 'Armature')
  thorax.name = 'thorax'

  print('Rotate thorax to face positive x-axis.')
  thorax.quat = (0, 0, 0, 1.)

  print('Set thorax mass.')
  thorax.find('geom', 'thorax').mass = _MASS['thorax'] * 1e-3

  print('Rename freejoint.')
  freejoint = thorax.get_children('joint')[0]
  freejoint.remove()
  thorax.insert('freejoint', 0, name='free')

  print('Sort thorax children.')
  sort_order = ['free', 'thorax', 'head', 'wing', 'abdomen',
                'haltere_left', 'haltere_right', 'coxa']
  child_names = [e.name for e in thorax._children]  # pylint: disable=protected-access
  resort = []
  for type_name in sort_order:
    indices = [j for j, s in enumerate(child_names) if type_name in s]
    resort.extend(indices)
  assert len(child_names) == len(resort)
  thorax._children = [thorax._children[i] for i in resort]  # pylint: disable=protected-access

  print('Sort labrums.')
  haustellum = thorax.find('body', 'haustellum')
  sort_order = ['haustellum', 'labrum_left', 'labrum_right']
  child_names = [e.name for e in haustellum._children]  # pylint: disable=protected-access
  resort = []
  for type_name in sort_order:
    indices = [j for j, s in enumerate(child_names) if type_name in s]
    resort.extend(indices)
  assert len(child_names) == len(resort)
  haustellum._children = [haustellum._children[i] for i in resort]  # pylint: disable=protected-access

  print('Retract haustellum.')
  haustellum.find('joint', 'rx_haustellum').range[1] = 0.7
  haustellum.find('joint', 'rx_haustellum').springref = 0.8

  print('Set body childclass.')
  thorax.childclass = 'body'

  print('Remove default-specified values.')
  for joint in thorax.find_all('joint'):
    if joint.tag != 'freejoint' and joint.limited is not None:
      joint.limited = None
  for geom in thorax.find_all('geom'):
    if geom.type is not None and geom.type == 'mesh':
      geom.type = None
    if geom.material is not None and geom.material.name == 'body':
      geom.material = None

  print('Rename leg elements.')
  legs = [body for body in thorax.find_all('body') if 'coxa' in body.name]
  links = ['coxa', 'femur', 'tibia', 'tarsus', 'tarsus2',
           'tarsus3', 'tarsus4']
  sternums = ['T1', 'T2', 'T3']
  sides = ['right', 'left']
  for sternum in sternums:
    for side in sides:
      for leg in legs:
        if sternum in leg.name and side in leg.name:
          body = leg
          for link in links:
            name = link +'_'+ sternum +'_'+ side
            old_name = body.name
            body.name = name
            for geom in body.get_children('geom'):
              if old_name in geom.name:
                geom.name = geom.name.replace(old_name, name)
            for joint in body.get_children('joint'):
              if old_name in joint.name:
                joint.name = joint.name.replace(old_name, name)
            children = body.get_children('body')
            if children:
              body = children[0]

  print('Remove tarsus abductors.')
  tarsi = [body for body in thorax.find_all('body') if 'tarsus' in body.name]
  for tarsus in tarsi:
    for joint in tarsus.get_children('joint'):
      if 'rz' in joint.name and 'tarsus_' not in joint.name:
        joint.remove()

  print('Add claw bodies and joints.')
  tarsi4 = [body for body in thorax.find_all('body') if 'tarsus4' in body.name]
  for tarsus4 in tarsi4:
    gclaw = [geom for geom in tarsus4.get_children('geom')
             if 'claw' in geom.name][0]
    ipos = tarsus4.pos
    claw = tarsus4.add('body', name=tarsus4.name.replace('tarsus4', 'claw'),
                       pos=ipos)
    name = gclaw.name.replace('_brown', '')
    material = gclaw.material
    mesh = gclaw.mesh
    quat = gclaw.quat
    pos = gclaw.pos - ipos
    gclaw.remove()
    claw.add('geom', name=name, material=material, quat=quat, mesh=mesh,
             pos=pos)
    joint_name = name.replace('tarsal_claw', 'rx_tarsus5')
    claw.insert('joint', 0, name=joint_name, axis=(1, 0, 0), pos=(0, 0, 0),
                range=(-.1, .1))

  print('Hair geoms should not contribute mass.')
  black_material = model.find('material', 'black')
  for geom in model.find_all('geom'):
    if geom.material == black_material:
      geom.mass = 0

  print('Symmetrize thorax children.')
  thorax_children = list(thorax.all_children())
  for left in thorax_children:
    if 'left' in left.name:
      right = thorax.find('body', left.name.replace('left', 'right'))
      pos = (right.pos + left.pos * np.array([1, -1, 1]))/2
      right.pos = pos
      left.pos = pos * np.array([1, -1, 1])

  print('Make collision materials.')
  base = model.asset.find('material', 'base')
  base.name = 'blue'
  base.rgba = (0.2, 0.3, 1, 1)
  model.asset.add('material', name='pink', rgba=(0.6, 0.3, 1, 1))

  print('== Infer collision geoms.')
  print('Infer leg collision geoms.')
  legs_meshes = []
  for leg in legs:
    legs_meshes.append(leg.find_all('geom'))
  for leg_meshes in legs_meshes:
    for geom in leg_meshes:
      if 'coxa' in geom.name:
        geom.type = 'ellipsoid'
      else:
        geom.type = 'capsule'

  print('Infer wing collision geoms.')
  wings = [body for body in thorax.find_all('body') if 'wing' in body.name]
  wing_meshes = []
  for wing in wings:
    for geom in wing.find_all('geom'):
      wing_meshes.append(geom)
      geom.type = 'ellipsoid'

  print('Infer mouth collision geoms.')
  haustellum_mesh = thorax.find('geom', 'haustellum')
  haustellum_mesh.type = 'capsule'
  mouth_meshes = [g for g in thorax.find_all('geom') if 'labrum' in g.name]
  for geom in mouth_meshes:
    geom.type = 'ellipsoid'
  mouth_meshes.append(haustellum_mesh)

  print('Recompile.')
  physics = mjcf.Physics.from_mjcf_model(model)

  print('Re-center thorax CoM at origin.')
  offset = physics.bind(thorax).xipos
  offset[1] = 0.0
  for geom in thorax.get_children('geom'):
    geom.pos -= offset
    if geom.quat is not None and geom.quat[0] == 1:
      geom.quat = None
  for body in thorax.get_children('body'):
    body.pos -= offset

  print('Set T3 coxa frame positions.')
  for leg in legs:
    if 'T3' in leg.name:
      pos = np.array((0.021875, 0.01925, -0.04025))
      if 'left' in leg.name:
        pos[1] *= -1
      change_body_frame(leg, pos, leg.quat)
      # reset joint positions
      for joint in leg.get_children('joint'):
        joint.pos = None

  print('Set leg frames, in reverse order.')
  for leg in legs:
    coxa = physics.bind(leg)
    femur = leg.get_children('body')[0]
    tibia = femur.get_children('body')[0]
    tarsus = tibia.get_children('body')[0]
    claw = leg.find('body', leg.name.replace('coxa', 'claw'))
    bound_femur = physics.bind(femur)
    bound_tibia = physics.bind(tibia)
    bound_tarsus = physics.bind(tarsus)
    bound_claw = physics.bind(claw)

    if 'left' in leg.name:
      quat = np.array((0., 1., 0, 0))
    else:
      quat = np.array((1., 0, 0, 0))

    # upper tarsus
    if 'T2' in leg.name or 'T3' in leg.name:
      tarsus_to_claw = bound_claw.xpos - bound_tarsus.xpos
      tarsus_to_tibia = bound_tibia.xpos - bound_tarsus.xpos
      extend = -np.cross(tarsus_to_claw, tarsus_to_tibia)
      extend /= np.linalg.norm(extend)
      twist = tarsus_to_claw
      twist /= np.linalg.norm(twist)
      if 'right' in leg.name:
        twist *= -1
      abduct = np.cross(-extend, twist)
      xmat = np.vstack((-extend, twist, abduct)).T
      mat = bound_tibia.xmat.reshape(3, 3).T @ xmat
      tarsus_pos = tarsus.pos.copy()
      tarsus_pos[2] -= .00175
      tarsus_quat = mat_to_quat(mat)
    else:
      if 'left' in leg.name:
        rquat = np.array((1., 0, 0, 0))
      else:
        rquat = np.array((0., 1., 0, 0))
      tarsus_quat = mul_quat(rquat, tarsus.quat)
    tarsus_pos = tarsus.pos.copy()
    tarsus_pos[2] -= .00175
    change_body_frame(tarsus, tarsus_pos, tarsus_quat)

    # set lower tarsi to match upper tarsus
    parent = tarsus
    children = parent.get_children('body')
    while children:
      child = children[0]
      change_body_frame(child, child.pos, np.array((1., 0, 0, 0)))
      parent = child
      children = parent.get_children('body')

    # tibia
    change_body_frame(tibia, tibia.pos, mul_quat(tibia.quat, quat))

    # femur
    femur_to_tibia = bound_tibia.xpos - bound_femur.xpos
    femur_to_coxa = coxa.xpos - bound_femur.xpos
    extend = np.cross(femur_to_tibia, femur_to_coxa)
    extend /= np.linalg.norm(extend)
    twist = femur_to_tibia
    twist /= np.linalg.norm(twist)
    if 'right' in femur.name:
      twist *= -1
    abduct = np.cross(-extend, twist)
    xmat = np.vstack((-extend, twist, abduct)).T
    mat = coxa.xmat.reshape(3, 3).T @ xmat
    change_body_frame(femur, femur.pos, mat_to_quat(mat))

    # coxa
    twist = -femur_to_coxa
    twist /= np.linalg.norm(twist)
    if 'right' in leg.name:
      twist *= -1
    abduct = np.cross(extend, twist)
    xmat = np.vstack((extend, twist, abduct)).T
    mat = physics.bind(thorax).xmat.reshape(3, 3).T @ xmat
    change_body_frame(leg, leg.pos, mat_to_quat(mat))

  print('Recompile.')
  physics = mjcf.Physics.from_mjcf_model(model)

  print('== Make collision geoms.')
  print('Make leg collision geoms.')
  for leg_meshes in legs_meshes:
    fromto = None
    for geom in leg_meshes:
      pos = physics.bind(geom).pos
      quat = physics.bind(geom).quat
      size = physics.bind(geom).size
      children = geom.parent.get_children('body')
      gtype = geom.type
      if len(children) == 1:
        dclass = 'collision'
        if 'coxa' in geom.name:
          fromto = None
          if 'T3' in geom.name:
            axis = children[0].pos / np.linalg.norm(children[0].pos)
            quat = quat_z2vec(axis)
            size = np.array((0.007, 0.00875, 0.016625))
            pos[1] -= 0.00175 * (1 if 'left' in geom.name else -1)
          elif 'T2' in geom.name:
            axis = children[0].pos / np.linalg.norm(children[0].pos)
            quat = quat_z2vec(axis)
            size = np.array((0.007875, 0.007, 0.014875))
            pos *= 0.7875
          else:
            size *= 1.225
        else:
          quat = None
          pos = None
          from_ = [0, 0, 0]
          if 'femur' in geom.name:
            from_[2] = .004375 * (1 if 'left' in geom.name else -1)
            from_[1] = .002625 * (1 if 'left' in geom.name else -1)
            fromto = np.hstack((from_, children[0].pos*0.95))
            if 'T3' in geom.name:
              fromto[0] += 0.002625 * (-1 if 'left' in geom.name else 1)
              fromto[3] += 0.002625 * (-1 if 'left' in geom.name else 1)
              fromto[4] += 0.002625 * (1 if 'left' in geom.name else -1)
            if 'T1' in geom.name:
              fromto[0] += .002625 * (1 if 'left' in geom.name else -1)
            size = (1.05 * size[0],)
          elif 'tibia' in geom.name and 'T3' in geom.name:
            fromto = np.hstack((from_, children[0].pos))
            fromto[1] += .003 * (-1 if 'left' in geom.name else 1)
            fromto[2] += .005 * (-1 if 'left' in geom.name else 1)
            size = (1.05 * size[0],)
          else:
            fromto = np.hstack((from_, children[0].pos))
            size = (1.225 * size[0],)
          gtype = None
      else:
        dclass = 'adhesion-collision'
        pos = None
        quat = None
        size = (1.225 * size[0],)
      name = geom.name + '_collision'
      index = geom.parent._children.index(geom)  # pylint: disable=protected-access
      geom.parent.insert('geom', index+1, fromto=fromto, size=size, quat=quat,
                         name=name, type=gtype, dclass=dclass, pos=pos)
      geom.type = None

  print('Make wing collision geoms.')
  for geom in wing_meshes:
    pos = physics.bind(geom).pos
    quat = physics.bind(geom).quat
    name = geom.name + '_collision'
    #  Adjust MuJoCo's geom fits.
    if 'membrane' in name:
      size = (0.0030625, 0.055125, 0.11375)
      angle = 0.11 * (1 if 'right' in name else -1)
      lateral = 0.02625 * (-1 if 'right' in name else 1)
      forward = -0.002625
      in_plane = np.array([[forward], [lateral]])
    else:
      size = (0.0021875, 0.0175, 0.11375)
      angle = 0.05 * (1 if 'right' in name else -1)
      lateral = 0.0035 * (-1 if 'right' in name else 1)
      forward = 0.0100625
      in_plane = np.array([[forward], [lateral]])
    mat = quat_to_mat(physics.bind(geom).quat)
    offset = mat[:, 1:3] @ in_plane
    pos = pos + offset.flatten()
    rotate = np.array((np.cos(angle/2), np.sin(angle/2), 0, 0))
    quat = mul_quat(quat, rotate)

    index = geom.parent._children.index(geom)  # pylint: disable=protected-access
    colgeom = geom.parent.insert('geom', index+1, pos=pos, quat=quat, size=size,
                                 name=name, type=geom.type, dclass='collision')
    geom.type = None
    # set wing inertias using a custom geom
    geom.mass = 0
    if 'membrane' in name:
      colgeom.dclass = 'collision-membrane'
      # add fluid-interaction geoms
      gsize = colgeom.size.copy()
      gsize[0] = 0.0005  # 2 microns
      gname = geom.parent.name + '_inertial'
      geom.parent.insert('geom', index+2, pos=colgeom.pos, quat=colgeom.quat,
                         name=gname, dclass='wing-inertial', size=gsize)
      gname = geom.parent.name + '_fluid'
      geom.parent.insert('geom', index+2, pos=colgeom.pos, quat=colgeom.quat,
                         name=gname, dclass='wing-fluid', size=gsize)

  print('Make mouth collision geoms.')
  for geom in mouth_meshes:
    pos = physics.bind(geom).pos
    quat = physics.bind(geom).quat

    #  Adjust MuJoCo's geom fits
    if 'haustellum' in geom.name:
      size = (0.007875, 0.007875)
    else:
      size = (0.0035, 0.00875, 0.013125)
      pos *= 1.22
    # save collision geom
    name = geom.name + '_collision'
    index = geom.parent._children.index(geom)  # pylint: disable=protected-access
    dclass = 'adhesion-collision' if 'labrum' in geom.name else 'collision'
    geom.parent.insert('geom', index+1, pos=pos, quat=quat, size=size,
                       name=name, type=geom.type, dclass=dclass)
    geom.type = None

  print('Make antennae collision geoms.')
  antennae = [b for b in thorax.find_all('body') if 'antenna' in b.name]
  for antenna in antennae:
    pos = (0, 0.011375, 0.002625)
    zaxis = (0, 0.875, -0.175)
    size = (0.0048125, 0.00875)
    name = antenna.name + '_collision'
    quat = np.zeros(4)
    mjlib.mju_quatZ2Vec(quat, np.asarray(zaxis))
    antenna.insert('geom', -1, name=name, type='capsule', dclass='collision',
                   pos=pos, quat=quat, size=size)

  print('Make abdomen collision geoms.')
  abdomens = [b for b in thorax.find_all('body') if 'abdomen' in b.name]
  for abdomen in abdomens:
    name = abdomen.name + '_collision'
    inertia = physics.bind(abdomen).inertia
    mass = physics.bind(abdomen).mass
    pos = physics.bind(abdomen).ipos
    quat = physics.bind(abdomen).iquat
    # Get inertia box
    size = np.zeros(3)
    for i in range(3):
      not_i = set([0, 1, 2]) - set([i])
      accum = 0.0
      for j in not_i:
        accum += inertia[j]
      accum -= inertia[i]
      size[i] = np.sqrt(accum / mass * 6) / 2

    if '7' in abdomen.name:
      size = (0.02625,)
      gtype = 'sphere'
      quat = None
    else:
      radius = size[1:3].max()
      height = size[0]
      # axis = children[0].pos / np.linalg.norm(children[0].pos)
      if '1' in abdomen.name:
        angle = np.pi/2
        rotate = np.array((np.cos(angle/2), 0, np.sin(angle/2), 0))
        quat = mul_quat(quat, rotate)
        pos[2] -= 0.00525
        height *= 1.5
      else:
        # quat = quat_z2vec(np.array((0, 1., 0)))
        angle = np.pi/2 + 0.1
        rotate = np.array((np.cos(angle/2), 0, np.sin(angle/2), 0))
        quat = mul_quat(quat, rotate)
        pos[2] -= 0.0875 * size[0]
        radius *= 1.05
      if '3' in abdomen.name:
        pos[2] += 0.0021875
      if '6' in abdomen.name:
        angle = - 0.096
        rotate = np.array((np.cos(angle/2), 0, np.sin(angle/2), 0))
        quat = mul_quat(quat, rotate)
      size = np.array([radius, height, 0])
      gtype = 'cylinder'
    # Make collision geom.
    abdomen.insert('geom', 4, type=gtype, dclass='collision', size=size,
                   pos=pos, quat=quat, name=name)

  print('Add bespoke collision geoms.')
  # Thorax.
  angle = -1.
  rotate = np.array((np.cos(angle/2), 0, np.sin(angle/2), 0))
  thorax.insert('geom', 2, name='thorax_collision', dclass='collision',
                type='ellipsoid', size=(0.04375, 0.04375, 0.055125),
                pos=(-0.0175, 0, -0.002625), quat=rotate)
  angle = -.4
  rotate = np.array((np.cos(angle/2), 0, np.sin(angle/2), 0))
  thorax.insert('geom', 3, name='thorax_collision2', dclass='collision',
                type='ellipsoid', size=(0.049875, 0.028, 0.011375),
                pos=(-0.011375, 0, 0.021875), quat=rotate)
  # Head.
  head = thorax.find('body', 'head')
  head.insert('site', 0, name='head', pos=(0, 0.015, 0), quat=(0, 0, 0, 1))
  head.insert('geom', 4, name='head_collision', dclass='collision',
              type='ellipsoid', size=(0.0455, 0.02625, 0.032375),
              pos=(0, 0.014875, 0.000875), euler=(.3, 0, 0))
  # Eye cameras.
  mat = np.zeros((3, 3))
  mat[:, 0] = [.45, -1, -.3]
  mat[:, 0] /= np.linalg.norm(mat[:, 0])
  mat[:, 1] = [-.2, 0, 1]
  mat[:, 1] /= np.linalg.norm(mat[:, 1])
  mat[:, 1] -= mat[:, 0] * np.dot(mat[:, 0], mat[:, 1])
  mat[:, 1] /= np.linalg.norm(mat[:, 1])
  mat[:, 2] = np.cross(mat[:, 0], mat[:, 1])
  head.insert('camera', 8, name='eye_right', fovy=140,
              pos=(0.021875, 0.013125, 0), quat=mat_to_quat(mat))
  mat[:, 0] = [.45, 1, .3]
  mat[:, 0] /= np.linalg.norm(mat[:, 0])
  mat[:, 1] = [.2, 0, 1]
  mat[:, 1] /= np.linalg.norm(mat[:, 1])
  mat[:, 1] -= mat[:, 0] * np.dot(mat[:, 0], mat[:, 1])
  mat[:, 1] /= np.linalg.norm(mat[:, 1])
  mat[:, 2] = np.cross(mat[:, 0], mat[:, 1])
  head.insert('camera', 9, name='eye_left', fovy=140,
              pos=(-0.021875, 0.013125, 0), quat=mat_to_quat(mat))
  # Rostrum.
  rostrum = thorax.find('body', 'rostrum')
  rostrum.insert('geom', 2, name='rostrum_collision', dclass='collision',
                 type='ellipsoid', size=(0.013125, 0.021875, 0.013125),
                 pos=(0, 0.0175, 0.002625), euler=(.1, 0, 0))
  for side in ['_left', '_right']:
    fromto = np.array((-0.006125, 0.032375, 0, -0.011375, 0.0245, -0.023625))
    if 'r' in side:
      fromto[0] *= -1
      fromto[3] *= -1
    rostrum.insert('geom', 2, name='rostrum_collision'+side, dclass='collision',
                   size=(0.0035,), fromto=fromto)

  print('Recompile.')
  physics = mjcf.Physics.from_mjcf_model(model)

  print('== Finalise joints and defaults.')

  print('Wing joints.')
  # Set pitch joint range.
  pitch_range = 2. / 3. * np.pi * np.array((-1., 1.)) - _YAW_AXIS_PITCH
  pitch_range = np.round(pitch_range*100)/100
  model.find('default', 'pitch').joint.range = pitch_range
  wing_quats = [np.array([np.cos(angle/2), 0, np.sin(angle/2), 0]) for
                angle in [_YAW_AXIS_PITCH, _YAW_AXIS_PITCH + np.pi]]

  # Set wing joints.
  for i, wing in enumerate(wings):
    wing.childclass = model.find('default', 'wing')
    change_body_frame(wing, wing.pos, wing_quats[i])
    for joint in wing.get_children('joint'):
      joint.range = None
      joint.axis = None
      joint.pos = None
      if 'rx' in joint.name:
        joint.name = joint.name.replace('rx', 'yaw')
        joint.dclass = 'yaw'
      if 'ry' in joint.name:
        joint.name = joint.name.replace('ry', 'roll')
        joint.dclass = 'roll'
      if 'rz' in joint.name:
        joint.name = joint.name.replace('rz', 'pitch')
        joint.dclass = 'pitch'

  print('Symmetrize wing geoms children.')
  wing_left = thorax.find('body', 'wing_left')
  wing_right = thorax.find('body', 'wing_right')
  for lgeom in wing_left.find_all('geom'):
    rgeom = wing_right.find('geom', lgeom.name.replace('left', 'right'))
    pos = (lgeom.pos - rgeom.pos)/2
    lgeom.pos = pos
    rgeom.pos = -pos

  print('Symmetrize Antennae left->right.')
  head_xmat = physics.bind(head).xmat.reshape(3, 3)
  left_xmat = physics.bind(antennae[0]).xmat.reshape(3, 3)
  right_xmat = left_xmat * np.array(([-1], [1], [-1]))
  right_mat = head_xmat.T @ right_xmat
  right_quat = mat_to_quat(right_mat)
  change_body_frame(antennae[1], antennae[1].pos, right_quat)

  # Axis names dict:
  ax_names = {'rx': 'extend', 'ry': 'twist', 'rz': 'abduct'}

  print('Reorder all joint axes: (ry, rz, rx).')
  for body in model.find_all('body'):
    joint_index = np.array((-1, -1, -1))
    for joint in body.get_children('joint'):
      for i, axis in enumerate(ax_names):
        if axis in joint.name:
          joint_index[i] = body._children.index(joint)  # pylint: disable=protected-access
    if sum(joint_index >= 0) > 1:
      joint_reorder = [joint_index[i] for i in [2, 1, 0] if joint_index[i] >= 0]
      joint_reorder = list(filter(lambda a: a != -1, joint_reorder))
      joint_index = list(filter(lambda a: a != -1, joint_index))
      child_order = list(range(len(body._children)))  # pylint: disable=protected-access
      for i, index in enumerate(joint_index):
        child_order[index] = joint_reorder[i]
      body._children = [body._children[i] for i in child_order]  # pylint: disable=protected-access

  print('Rename all joints.')
  for joint in model.find_all('joint'):
    for axis in ax_names:
      joint.name = joint.name.replace(axis, ax_names[axis])

  print('Head joints.')
  head_joint_range = {'twist': (-3, 3),
                      'abduct': (-.2, .2),
                      'extend': (-.5, .3)}
  head.childclass = model.find('default', 'head')
  for joint in head.get_children('joint'):
    for joint_range in head_joint_range:
      if joint_range in joint.name:
        joint.range = head_joint_range[joint_range]
  for i, side in enumerate(sides):
    labrum = head.find('body', 'labrum_' + side)
    labrum.childclass = model.find('default', 'labrum')
    pos = labrum.pos - (i*2-1) * np.array((0.002625, 0, 0))
    pos -= np.array((0, 0.002625, 0))
    change_body_frame(labrum, pos, labrum.quat)
    for joint in labrum.get_children('joint'):
      joint.pos = None

  print('Abdominal joints.')
  abdomens[0].childclass = model.find('default', 'abdomen')
  abdomens[0].name = abdomens[0].name.replace('abdomen_1', 'abdomen')
  for child in abdomens[0]._children:  # pylint: disable=protected-access
    child.name = child.name.replace('abdomen_1', 'abdomen')
  def_ab_abduct = model.find('default', 'abduct_abdomen')
  def_ab_extend = model.find('default', 'extend_abdomen')
  for abdomen in abdomens:
    for joint in abdomen.get_children('joint'):
      if 'extend' in joint.name:
        joint.dclass = def_ab_extend
      else:
        joint.dclass = def_ab_abduct
      joint.axis = None
      joint.range = None

  print('Haltere joints.')
  halteres = [b for b in thorax.find_all('body') if 'haltere' in b.name]
  for haltere in halteres:
    if 'left' in haltere.name:
      rotate_y = np.array((0., 0., 1., 0.))
      change_body_frame(haltere, haltere.pos, mul_quat(haltere.quat, rotate_y))
    for joint in haltere.get_children('joint'):
      joint.pos = None
      joint.axis = None
      joint.range = None
    haltere.childclass = model.find('default', 'haltere')

  print('Antennae joints.')
  for antenna in antennae:
    for joint in antenna.get_children('joint'):
      for axis in ax_names.values():
        if axis in joint.name:
          joint.dclass = model.find('default', 'antenna_' + axis)
          joint.axis = None
          joint.range = None
          joint.pos = None

  print('Leg joints.')
  # abduct_femur -> twist_femur
  for leg in legs:
    for joint in leg.find_all('joint'):
      if 'abduct_femur' in joint.name:
        joint.name = joint.name.replace('abduct_femur', 'twist_femur')

  for leg in legs:
    leg.childclass = model.find('default', 'leg')
    parent = leg
    while parent:
      # set joint properties
      for joint in parent.get_children('joint'):
        joint.pos = None
        joint.range = None
        joint.axis = None
        def_name = joint.name.replace('_left', '').replace('_right', '')
        if not model.find('default', def_name):
          def_name = def_name[:-3]  # remove _TX
          if not model.find('default', def_name):
            def_name = def_name[:-1]  # remove tarsus index
            if not model.find('default', def_name):
              raise ValueError('Default class not found for joint.')
        joint.dclass = model.find('default', def_name)

      # remove some unit quaternions while we're here
      if parent.quat is not None and parent.quat[0] == 1:
        parent.quat = None

      children = parent.get_children('body')
      if children:
        parent = children[0]
      else:
        parent = None

  print('Abdomen tendons.')
  abdomen_tendons = {'abduct_abdomen': None, 'extend_abdomen': None}
  for name in abdomen_tendons:
    tendon = model.tendon.add('fixed', name=name)
    abdomen_tendons[name] = tendon
    for abdomen in abdomens:
      for joint in abdomen.get_children('joint'):
        if name in joint.name:
          tendon.add('joint', joint=joint, coef=1)

  print('Tarsus tendons.')
  tarsus_tendons = {}
  for leg in legs:
    # Tendon couples tarsi starting from tarsus2.
    parent = leg.find('body', leg.name.replace('coxa', 'tarsus2'))
    name = 'extend_' + leg.name.replace('coxa', 'tarsus2')
    tendon = model.tendon.add('fixed', name=name, dclass='extend_tarsus')
    tarsus_tendons[name] = tendon
    while parent.get_children('joint'):
      joint = parent.get_children('joint')[0]
      if 'tarsus2' in joint.name:
        coef = 1
      else:
        coef = .5
      tendon.add('joint', joint=joint, coef=coef)
      if parent.get_children('body'):
        parent = parent.get_children('body')[0]
      else:
        break

  print('Add "general" actuators.')
  for joint in model.find_all('joint'):
    if 'free' in joint.name or 'haltere' in joint.name:
      continue
    if 'abdomen' in joint.name:
      for tendon_name in abdomen_tendons:
        if joint.name == tendon_name:
          dclass = model.find('default', tendon_name)
          num_joints = len(abdomen_tendons[tendon_name].get_children('joint'))
          model.actuator.add('general', name=tendon_name,
                             tendon=abdomen_tendons[tendon_name], dclass=dclass,
                             ctrlrange=num_joints*dclass.joint.range)
      continue
    if 'tarsus' in joint.name:
      if 'tarsus2' in joint.name:
        name = joint.name
        tendon = model.find('tendon', name)
        if tendon is not None:
          trange = np.array((0.0, 0.0))
          for tjoint in tendon.get_children('joint'):
            trange += tjoint.joint.dclass.joint.range * tjoint.coef
          model.actuator.add('general', name=name,
                             tendon=tendon, dclass=tendon.dclass,
                             ctrlrange=trange)
          continue
      elif 'abduct_tarsus' not in joint.name and 'tarsus_' not in joint.name:
        continue
    dclass = joint.dclass
    parent = joint.parent
    if dclass is None:
      while parent.childclass is None:
        parent = parent.parent
      dclass = parent.childclass
    if (
        'twist' in joint.name
        or 'abduct' in joint.name
        or 'extend' in joint.name
    ):
      if joint.range is not None:
        jrange = joint.range
      elif dclass.joint.range is not None:
        jrange = dclass.joint.range
      else:
        jrange = dclass.parent.joint.range
      # if 'twist_coxa_T2' in joint.name:
      #   import ipdb; ipdb.set_trace()
      if ('coxa' in joint.name or 'femur' in joint.name or
          'tibia' in joint.name or 'tarsus' in joint.name):
        assert (dclass.joint.range is not None or
                dclass.parent.joint.range is not None)
        actrange = None
      else:
        actrange = jrange
      model.actuator.add('general', name=joint.name, joint=joint,
                         dclass=dclass, ctrlrange=actrange)
    else:  # wing joints
      model.actuator.add('general', name=joint.name, joint=joint,
                         dclass=dclass)

  print('Add "adhesion" actuators.')
  for body in model.find_all('body'):
    if 'claw' in body.name:
      model.actuator.add('adhesion', name='adhere_'+body.name,
                         body=body.name, dclass='adhesion_claw')
    if 'labrum' in body.name:
      model.actuator.add('adhesion', name='adhere_'+body.name,
                         body=body.name, dclass='adhesion_labrum')

  print('Remove abduction for tarsi and tibia.')
  for actuator in model.find_all('actuator'):
    if 'abduct_tarsus' in actuator.name or 'abduct_tibia' in actuator.name:
      actuator.remove()
  for joint in model.find_all('joint'):
    if 'abduct_tarsus' in joint.name or 'abduct_tibia' in joint.name:
      joint.remove()
  for deflt in model.find_all('default'):
    if deflt.dclass is not None:
      if 'abduct_tarsus' in deflt.dclass or 'abduct_tibia' in deflt.dclass:
        deflt.remove()

  print('Recompile.')
  physics = mjcf.Physics.from_mjcf_model(model)

  print('Print qpos0 contacting bodies:')
  exclude_pairs = []
  for con in physics.data.contact:
    body1 = physics.model.id2name(physics.model.geom_bodyid[con.geom1], 'body')
    body2 = physics.model.id2name(physics.model.geom_bodyid[con.geom2], 'body')
    if 'coxa' in body1 or 'coxa' in body2:
      continue
    print(f'  {body1} and {body2} are in contact.')

  print('Exclude contacts')
  # Abdomen-abdomen (segments two apart), then wing-abdomen.
  for i in range(len(abdomens) - 2):
    exclude_pairs.append((abdomens[i], abdomens[i+2]))
  exclude_pairs.append((wings[0], abdomens[0]))
  exclude_pairs.append((wings[0], abdomens[1]))
  exclude_pairs.append((wings[0], abdomens[2]))
  exclude_pairs.append((wings[1], abdomens[0]))
  exclude_pairs.append((wings[1], abdomens[1]))
  exclude_pairs.append((wings[1], abdomens[2]))
  # Wing-wing.
  exclude_pairs.append((wings[0], wings[1]))
  # Coxa-coxa, coxa-femur, femur-femur.
  for left_coxa in legs:
    if 'right' in left_coxa.name:
      continue
    right_coxa_name = left_coxa.name.replace('left', 'right')
    right_coxa = left_coxa.parent.find('body', right_coxa_name)
    left_femur = left_coxa.get_children('body')[0]
    right_femur = right_coxa.get_children('body')[0]
    exclude_pairs.append((left_coxa, right_coxa))
    exclude_pairs.append((left_coxa, right_femur))
    exclude_pairs.append((left_femur, right_coxa))
    exclude_pairs.append((left_femur, right_femur))
  # rostrum-labrum.
  for rostrum in head.find_all('body'):
    if 'rostrum' in rostrum.name:
      for labrum in head.find_all('body'):
        if 'labrum' in labrum.name:
          exclude_pairs.append((rostrum, labrum))

  for pair in exclude_pairs:
    model.contact.add('exclude', body1=pair[0], body2=pair[1],
                      name=pair[0].name + '_' + pair[1].name)

  print('Re-center thorax CoM at origin, again.')
  offset = physics.bind(thorax).xipos
  offset[1] = 0.0
  thorax.pos = np.array((0., 0., 0.))
  change_body_frame(thorax, offset, np.array((1., 0., 0., 0.)))
  if thorax.quat[0] == 1:
    thorax.quat = None
  thorax.pos = None

  print('Add sensors.')
  thorax.insert('site', 1, name='thorax')
  angle = -_YAW_AXIS_PITCH
  thorax.insert('site', 2, name='hover_up_dir',
                quat=np.array([np.cos(angle/2), 0, np.sin(angle/2), 0]),
                pos=(0.02625, 0, 0.02625))
  model.sensor.add('accelerometer', name='accelerometer', site='thorax')
  model.sensor.add('gyro', name='gyro', site='thorax')
  model.sensor.add('velocimeter', name='velocimeter', site='thorax')
  touch_sites = []
  force_sites = []
  for leg in legs:
    for body in leg.find_all('body'):
      if 'claw' in body.name:
        for geom in body.get_children('geom'):
          if 'collision' in geom.name:
            site = body.add('site', name=body.name,
                            dclass='adhesion-collision',
                            fromto=geom.fromto,
                            size=geom.size*1.1)
            touch_sites.append(site)
      if 'tarsus_' in body.name:
        site = body.insert('site', -1, name=body.name)
        force_sites.append(site)
  for site in force_sites:
    model.sensor.add('force', name='force_' + site.name, site=site)
  for site in touch_sites:
    model.sensor.add('touch', name='touch_' + site.name, site=site)

  print('Add thorax light and cameras.')
  thorax.insert('light', 1, name='tracking', mode='trackcom', pos=(0, 0, 1))
  thorax.insert('light', 1, name='left', mode='trackcom', pos=(0, 1, 1),
                dir=(0, -1, -1), diffuse=(0.3, 0.3, 0.3))
  thorax.insert('light', 1, name='right', mode='trackcom', pos=(0, -1, 1),
                dir=(0, 1, -1), diffuse=(0.3, 0.3, 0.3))
  thorax.insert(
      'camera',
      2,
      name='track1',
      mode='trackcom',
      pos=(0.6, 0.6, 0.22),
      quat=(0.31246, 0.22094, 0.5334, 0.75434))
  thorax.insert(
      'camera',
      3,
      name='track2',
      mode='trackcom',
      pos=(0, -1.1, 0.1),
      quat=(0.70711, 0.70711, 0, 0))
  thorax.insert(
      'camera',
      4,
      name='track3',
      mode='trackcom',
      pos=(-0.9, -0.9, 0.9),
      quat=(0.82047, 0.42471, -0.17592, -0.33985))
  thorax.insert(
      'camera',
      5,
      name='back',
      mode='track',
      pos=(-0.462, 0, 0.297),
      xyaxes=(0, -1, 0, 0.707, 0, 0.707))
  thorax.insert(
      'camera',
      6,
      name='side',
      mode='track',
      pos=(-0.055, 0.424, -0.064),
      xyaxes=(-1, 0, 0, 0, 0, 1))
  thorax.insert(
      'camera',
      7,
      name='bottom',
      mode='track',
      pos=(0.01, 0, -0.516),
      xyaxes=(0, 1, 0, .991, 0, 0.136))

  print('Check masses (values in milligrams)')
  def print_mass(name, current, emp):
    """Prints one mass-table row: part name, modeled, empirical, emp/modeled."""
    ratio = emp / current
    print(f'  {name:10}\t{current:.4g}\t\t{emp:.4g}\t\t{ratio:.4g}')
  def print_masses():
    """Prints a table comparing modeled and empirical body-part masses.

    Modeled masses are read from the compiled `physics` and scaled by 1e3 to
    match the milligram values in `_MASS`. The 1e3 factor implies that the model
    masses are in grams; confirm this against the defaults XML. The thorax row
    uses the thorax body's own mass. The head and abdomen rows use subtree
    masses. The wing and leg rows are per-appendage averages over the two wings
    and the six legs.
    """
    print('  part\t\tmodeled\t\tempirical\tratio')
    own_mass = physics.named.model.body_mass
    subtree_mass = physics.named.model.body_subtreemass
    print_mass('thorax', 1e3*own_mass['thorax'], _MASS['thorax'])
    for part_name in ('head', 'abdomen'):
      print_mass(part_name, 1e3*subtree_mass[part_name], _MASS[part_name])
    both_wings = sum(1e3*subtree_mass[side] for side in ('wing_left',
                                                          'wing_right'))
    print_mass('wing', both_wings/2, _MASS['wing'])
    all_legs = sum(1e3*subtree_mass[coxa.name] for coxa in legs)
    print_mass('leg', all_legs/6, _MASS['leg'])

  print_masses()
  total_mass_model = 1e3*physics.named.model.body_subtreemass['thorax']
  total_mass_emp = (_MASS['head'] + _MASS['thorax'] + _MASS['abdomen'] +
                    6 * _MASS['leg'] + 2 * _MASS['wing'])
  print(f'  Total Mass\t{total_mass_model:.4g}\t\t{total_mass_emp:.4g}')

  print('Change order: axis_body -> body_axis.')
  elements = model.find_all('actuator') + model.find_all('joint')
  for element in elements:
    if 'adhere' in element.name:
      continue
    parts_nested = [s.split('-') for s in element.name.split('_')]
    parts = list(itertools.chain.from_iterable(parts_nested))
    if len(parts) < 2:
      continue
    order = list(range(len(parts)))
    order[0] = 1
    order[1] = 0
    element.name = '_'.join([parts[i] for i in order])

  print('Remove unnecessary "extend"s.')
  for joint in model.find_all('joint'):
    if '_extend' in joint.name:
      joint.name = joint.name.replace('_extend', '')
  for actuator in model.find_all('actuator'):
    if '_extend' in actuator.name:
      actuator.name = actuator.name.replace('_extend', '')
  for tendon in model.find_all('tendon'):
    if 'extend_' in tendon.name:
      tendon.name = tendon.name.replace('extend_', '')

  def shorten_names():
    """Abbreviates the names of all actuators and joints in `model` in place.

    Each name is split on '_' and '-'. A common list of words is abbreviated in
    every part. A second list is also applied, after the common one, when the
    original name is longer than 10 characters. The parts are re-joined with
    '_' and the first character is lower-cased. If the third-from-last
    character is a digit, it is moved to the end and the character before the
    last one is dropped, e.g. 'coxa_ab_1_L' -> 'coxa_ab_L1'.
    """
    print('Shortening actuator names')
    # Abbreviations applied to every name part, in this order.
    common_abbrevs = (
        ('left', 'L'), ('right', 'R'), ('extend', 'ex'),
        ('abduct', 'ab'), ('twist', 'tw'), ('wing', 'wng'),
    )
    # Extra abbreviations, applied after the common ones, to long names only.
    # ('haltere' and 'labrum' are intentionally left unabbreviated.)
    long_name_abbrevs = (
        ('T1', '1'), ('T2', '2'), ('T3', '3'),
        ('tarsus', 'tars'), ('pitch', 'ptch'), ('femur', 'fem'),
        ('antenna', 'anten'), ('haustellum', 'haust'),
    )
    for element in model.find_all('actuator') + model.find_all('joint'):
      abbrevs = common_abbrevs
      if len(element.name) > 10:  # Length of the name before shortening.
        abbrevs = common_abbrevs + long_name_abbrevs
      tokens = []
      for chunk in element.name.split('_'):
        for token in chunk.split('-'):
          for old, new in abbrevs:
            token = token.replace(old, new)
          tokens.append(token)
      short_name = '_'.join(tokens)
      short_name = short_name[0].lower() + short_name[1:]
      # Move a trailing index behind the final character, dropping the
      # separator between them.
      if short_name[-3].isdigit():
        short_name = short_name[:-3] + short_name[-1] + short_name[-3]
      element.name = short_name

  # name-shortening: currently unused
  if 'short' in FINAL_MODEL:
    shorten_names()

  print('== XML cleanup.')
  print('Remove class="/" using lxml.')
  xml_string = model.to_xml_string('float', precision=3, zero_threshold=1e-7)
  root = etree.XML(xml_string,
                   etree.XMLParser(remove_blank_text=True))
  default_elem = root.find('default')
  root.insert(3, default_elem[0])
  root.remove(default_elem)

  print('Remove hashes from filenames.')
  meshes = [mesh for mesh in root.find('asset').iter() if mesh.tag == 'mesh']
  for mesh in meshes:
    name, extension = mesh.get('file').split('.')
    mesh.set('file', '.'.join((name[:-41], extension)))

  print('Get string from lxml.')
  xml_string = etree.tostring(root, pretty_print=True)
  xml_string = xml_string.replace(b' class="/"', b'')

  print('Remove gravcomp="0".')
  xml_string = xml_string.replace(b' gravcomp="0"', b'')

  print('Insert spaces between top level elements.')
  lines = xml_string.splitlines()
  newlines = []
  for line in lines:
    newlines.append(line)
    if line.startswith(b'  <'):
      if line.startswith(b'  </') or line.endswith(b'/>'):
        newlines.append(b'')
  newlines.append(b'')
  xml_string = b'\n'.join(newlines)

  print(f'Save {FINAL_MODEL} to file.')
  with open(ASSET_RELPATH + '/../' + FINAL_MODEL, 'wb') as f:
    f.write(xml_string)

  print('Done.')


if __name__ == '__main__':
  # Let absl parse command-line flags, then invoke main(argv).
  app.run(main=main)
